# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-xla-compile-device-type="GPU" -tf-enable-soft-placement-on-import=true -o - | FileCheck %s

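# Verify that the main graph is converted to a function, that args/rets are
# mapped correctly, and that ops in the main graph are retained. Also verify
# that the attributes requested on the command line (XLA compile device type,
# soft placement) are attached to the main function, and that the function
# referenced from the function library is converted as well.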

# CHECK:      func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xf32>)
# CHECK-SAME: _xla_compile_device_type = "GPU"
# CHECK-SAME: allow_soft_placement
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "args_0,args_1,args_2,args_3"
# CHECK-SAME: outputs = "rets_0,rets_1"
# CHECK:          %[[ISLAND_0:.*]], %[[ISLAND_0_control:.*]] = tf_executor.island wraps "tf.Const"
# CHECK:          %[[ISLAND_1:.*]], %[[ISLAND_1_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[ISLAND_0]])
# CHECK:          %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island wraps "tf.StatefulPartitionedCall"
# CHECK-SAME:       f = @[[FUNC:[a-z0-9]*]]
# CHECK:          tf_executor.fetch %[[ISLAND_1]], %[[ISLAND_2]] : tensor<*xf32>, tensor<*xf32>
# CHECK:      func private @[[FUNC]](%arg0: tensor<*xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf_type.resource>) -> tensor<*xf32>

# Main-graph argument 0: a resource-typed _Arg with no handle dtype/shape
# attributes. The expected main signature lists it as
# tensor<*x!tf_type.resource>.
node {
  name: "args_0"
  op: "_Arg"
  attr {
    key: "T"
    value {
      type: DT_RESOURCE
    }
  }
  attr {
    key: "index"
    value {
      i: 0
    }
  }
}
# Main-graph argument 1: a resource-typed _Arg that additionally carries
# _handle_dtypes (DT_FLOAT) and _handle_shapes (3x3x1x32). The expected main
# signature lists it as tensor<*x!tf_type.resource<tensor<3x3x1x32xf32>>>.
# It is also the second input of the "statefulpartitionedcall" node.
node {
  name: "args_1"
  op: "_Arg"
  attr {
    key: "T"
    value {
      type: DT_RESOURCE
    }
  }
  attr {
    key: "_handle_dtypes"
    value {
      list {
        type: DT_FLOAT
      }
    }
  }
  attr {
    key: "_handle_shapes"
    value {
      list {
        shape {
          dim {
            size: 3
          }
          dim {
            size: 3
          }
          dim {
            size: 1
          }
          dim {
            size: 32
          }
        }
      }
    }
  }
  attr {
    key: "index"
    value {
      i: 1
    }
  }
}
# Main-graph argument 2: a float _Arg with no shape attribute. The expected
# main signature lists it as tensor<*xf32>. No node in this file lists it as
# an input.
node {
  name: "args_2"
  op: "_Arg"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "index"
    value {
      i: 2
    }
  }
}
# Main-graph argument 3: an int32 _Arg whose _output_shapes attribute gives a
# static 2x4x6x8 shape. The expected main signature lists it as
# tensor<2x4x6x8xi32>. No node in this file lists it as an input.
node {
  name: "args_3"
  op: "_Arg"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 2
          }
          dim {
            size: 4
          }
          dim {
            size: 6
          }
          dim {
            size: 8
          }
        }
      }
    }
  }
  attr {
    key: "index"
    value {
      i: 3
    }
  }
}
# Float constant with value 0.0 and an empty tensor_shape (no dims listed).
# It feeds both the "identity" node and the first input of the
# "statefulpartitionedcall" node.
node {
  name: "const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
# Call into the library function named "function" (see the "f" attribute).
# Inputs are "const" (DT_FLOAT) and "args_1" (DT_RESOURCE), matching Tin; the
# single DT_FLOAT output (Tout) is returned through "rets_1". The test expects
# this node to be imported as a "tf.StatefulPartitionedCall" island whose "f"
# attribute refers to the converted private function.
node {
  name: "statefulpartitionedcall"
  op: "StatefulPartitionedCall"
  input: "const"
  input: "args_1"
  attr {
    key: "Tin"
    value {
      list {
        type: DT_FLOAT
        type: DT_RESOURCE
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_FLOAT
      }
    }
  }
  attr {
    key: "_gradient_op_type"
    value {
      s: "PartitionedCall-1205"
    }
  }
  attr {
    key: "config"
    value {
      s: ""
    }
  }
  # NOTE(review): presumably a serialized ConfigProto; kept as opaque escaped
  # bytes and must not be reformatted.
  attr {
    key: "config_proto"
    value {
      s: "\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001"
    }
  }
  attr {
    key: "executor_type"
    value {
      s: ""
    }
  }
  attr {
    key: "f"
    value {
      func {
        name: "function"
      }
    }
  }
}
# Float Identity of "const". Its result is returned through "rets_0"; the test
# expects it to be imported as a "tf.Identity" island consuming the constant.
node {
  name: "identity"
  op: "Identity"
  input: "const"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
# Main-graph return value 0: the float output of "identity". Expected to be
# the first operand of the fetch and the first tensor<*xf32> result of main.
node {
  name: "rets_0"
  op: "_Retval"
  input: "identity"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "index"
    value {
      i: 0
    }
  }
}
# Main-graph return value 1: the float output of "statefulpartitionedcall".
# Expected to be the second operand of the fetch and the second tensor<*xf32>
# result of main.
node {
  name: "rets_1"
  op: "_Retval"
  input: "statefulpartitionedcall"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "index"
    value {
      i: 1
    }
  }
}
# Function library holding the single function "function", which is the
# callee of the "statefulpartitionedcall" node. The test expects it to be
# converted to a private function taking (tensor<*xf32>,
# tensor<*x!tf_type.resource>) and returning tensor<*xf32>.
library {
  function {
    # Signature: a float input "inputs", a resource input, and one float
    # output "identity"; the function is marked stateful.
    signature {
      name: "function"
      input_arg {
        name: "inputs"
        type: DT_FLOAT
      }
      input_arg {
        name: "statefulpartitionedcall_args_1"
        type: DT_RESOURCE
      }
      output_arg {
        name: "identity"
        type: DT_FLOAT
      }
      is_stateful: true
    }
    # Body: a single Identity of the "inputs" argument. The resource argument
    # is not referenced by any node in the body.
    node_def {
      name: "Identity"
      op: "Identity"
      input: "inputs"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    # Binds the output_arg "identity" to output 0 of the Identity node.
    ret {
      key: "identity"
      value: "Identity:output:0"
    }
    # Argument 0 carries _user_specified_name = "inputs", which the test
    # expects to appear as the tf._user_specified_name argument attribute.
    arg_attr {
      key: 0
      value {
        attr {
          key: "_user_specified_name"
          value {
            s: "inputs"
          }
        }
      }
    }
    # Argument 1 has an empty attribute entry.
    arg_attr {
      key: 1
      value {
      }
    }
  }
}
# GraphDef version information: only the producer version (121) is recorded.
versions {
  producer: 121
}
